import pandas as pd
import numpy as np
import statsmodels.api as sm
import matplotlib.pyplot as plt

# Load data
df = pd.read_csv("BoltsUTD_PME20-25_Repl.csv")

# Regressors whose coefficients are estimated and plotted for each subgroup.
# NOTE(review): presumably 0/1 condition indicators with an omitted reference
# condition -- if they are exhaustive they would be collinear with the
# intercept; confirm against the study design.
predictors = ["Human_Only", "Tactical_AI", "Strategic_AI"]

# Drop rows with NA in any column the models use. Both outcomes are listed so
# the DV1 and DV2 models are fit on the same analytic sample; DV1 must be
# included because statsmodels' default missing="none" does no NaN handling.
df = df.dropna(subset=predictors + ["Threat_Novelty_Index", "DV1", "DV2"])

# Median split on threat novelty, computed on the cleaned sample. Because the
# comparison is strict (>), rows exactly at the median are labelled "Low".
median_threat = df["Threat_Novelty_Index"].median()
df["Threat_Level"] = np.where(df["Threat_Novelty_Index"] > median_threat, "High", "Low")

# Fit OLS on one subgroup, plot slope coefficients with CIs, and return them.
def run_ols_and_plot(data, outcome_var, label, color, offset, ax=None, alpha=0.05):
    """Fit ``outcome_var ~ const + predictors`` by OLS and plot the slopes.

    The regressors are the module-level ``predictors`` list plus an intercept.
    Each non-intercept coefficient is drawn as a point with a ``(1 - alpha)``
    confidence-interval error bar at x-position ``index + offset``.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows to fit on; must contain the ``predictors`` columns and
        ``outcome_var``.
    outcome_var : str
        Name of the dependent-variable column.
    label : str
        Legend label for this series.
    color : str
        Matplotlib colour for the markers and error bars.
    offset : float
        Horizontal shift so several series can sit side by side per tick.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``).
    alpha : float, default 0.05
        Significance level of the interval; 0.05 gives 95% CIs.

    Returns
    -------
    pandas.DataFrame
        Indexed by predictor name, with columns ``coef``, ``ci_low`` and
        ``ci_high``.
    """
    if ax is None:
        ax = plt.gca()

    # Cast to float: bool/int indicator columns mixed with the float constant
    # would otherwise produce an object-dtype design matrix.
    # has_constant="add" guarantees a "const" column even if some regressor
    # happens to be constant within this subgroup (the default "skip" would
    # omit it and the .drop("const") below would raise KeyError).
    X = sm.add_constant(data[predictors].astype(float), has_constant="add")
    y = data[outcome_var].astype(float)

    # missing="drop" removes rows with NaN in y or X for this fit only.
    model = sm.OLS(y, X, missing="drop").fit()

    coefs = model.params.drop("const")
    # t-based interval (exact for OLS), rather than the normal 1.96*SE
    # approximation, which is too narrow in small subgroups.
    ci = model.conf_int(alpha=alpha).drop("const")
    ci_low = ci[0]
    ci_high = ci[1]

    # errorbar takes distances below/above each point as a (2, N) array.
    yerr = np.vstack([
        (coefs - ci_low).to_numpy(),
        (ci_high - coefs).to_numpy(),
    ])

    x = np.arange(len(predictors))
    ax.errorbar(
        x + offset,
        coefs,
        yerr=yerr,
        fmt='o',
        color=color,
        capsize=5,
        label=label
    )

    return pd.DataFrame({"coef": coefs, "ci_low": ci_low, "ci_high": ci_high})

# Plot both threat-novelty groups for DV1 and DV2 on one shared axis.
fig, ax = plt.subplots(figsize=(10, 5))

# Subset once and reuse the same frames for both outcomes.
low_threat = df[df["Threat_Level"] == "Low"]
high_threat = df[df["Threat_Level"] == "High"]

# Offsets spread the four series symmetrically around each predictor tick.
run_ols_and_plot(low_threat, "DV1", "DV1 - Low Threat", "red", -0.2)
run_ols_and_plot(high_threat, "DV1", "DV1 - High Threat", "blue", -0.1)
run_ols_and_plot(low_threat, "DV2", "DV2 - Low Threat", "darkred", 0.1)
run_ols_and_plot(high_threat, "DV2", "DV2 - High Threat", "darkblue", 0.2)

# Style plot: zero reference line, one tick per predictor, legend below axes.
x = np.arange(len(predictors))
ax.axhline(0, color='gray', linestyle='dotted')
ax.set_xticks(x)
ax.set_xticklabels(predictors, fontsize=10)
ax.set_ylabel("OLS Coefficient")
ax.set_title("Treatment Effects on DV1 and DV2 (Split by Threat Novelty Level)")
ax.legend(loc='lower center', bbox_to_anchor=(0.5, -0.3), ncol=2, frameon=False)

for spine in ['top', 'right', 'left']:
    ax.spines[spine].set_visible(False)

ax.grid(axis='y', linestyle='--', alpha=0.5)
plt.tight_layout()
# bbox_inches="tight" makes sure the legend anchored outside the axes
# (bbox_to_anchor y=-0.3) is included in the saved image.
plt.savefig("dv1_dv2_threat_split_effects.png", dpi=300, bbox_inches="tight")
plt.show()
